<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>Graph Attention Networks &mdash; Graph4NLP v0.4.1 documentation</title><link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
    <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
  <!--[if lt IE 9]>
    <script src="../../_static/js/html5shiv.min.js"></script>
  <![endif]-->
  <script id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
        <script src="../../_static/jquery.js"></script>
        <script src="../../_static/underscore.js"></script>
        <script src="../../_static/doctools.js"></script>
        <script src="../../_static/language_data.js"></script>
        <script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
    <script src="../../_static/js/theme.js"></script>
    <link rel="index" title="Index" href="../../genindex.html" />
    <link rel="search" title="Search" href="../../search.html" />
    <link rel="next" title="GraphSAGE" href="graphsage.html" />
    <link rel="prev" title="Gated Graph Neural Networks" href="ggnn.html" /> 
</head>

<body class="wy-body-for-nav"> 
  <div class="wy-grid-for-nav">
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search" >
            <a href="../../index.html" class="icon icon-home"> Graph4NLP
          </a>
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>
        </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
              <p class="caption"><span class="caption-text">Get Started</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../welcome/installation.html">Install Graph4NLP</a></li>
</ul>
<p class="caption"><span class="caption-text">User Guide</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../graphdata.html">Chapter 1. Graph Data</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dataset.html">Chapter 2. Dataset</a></li>
<li class="toctree-l1"><a class="reference internal" href="../construction.html">Chapter 3. Graph Construction</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="../gnn.html">Chapter 4. Graph Encoder</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="gcn.html">Graph Convolutional Networks</a></li>
<li class="toctree-l2"><a class="reference internal" href="ggnn.html">Gated Graph Neural Networks</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Graph Attention Networks</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#gat-module-construction-function">GAT Module Construction Function</a></li>
<li class="toctree-l3"><a class="reference internal" href="#gat-module-forward-function">GAT Module Forward Function</a></li>
<li class="toctree-l3"><a class="reference internal" href="#gatlayer-construction-function">GATLayer Construction Function</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="graphsage.html">GraphSAGE</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../decoding.html">Chapter 5. Decoder</a></li>
<li class="toctree-l1"><a class="reference internal" href="../classification.html">Chapter 6. Classification</a></li>
<li class="toctree-l1"><a class="reference internal" href="../evaluation.html">Chapter 7. Evaluations and Loss components</a></li>
</ul>
<p class="caption"><span class="caption-text">Module API references</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../modules/data.html">graph4nlp.data</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../modules/datasets.html">graph4nlp.datasets</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../modules/graph_construction.html">graph4nlp.graph_construction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../modules/graph_embedding.html">graph4nlp.graph_embedding</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../modules/prediction.html">graph4nlp.prediction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../modules/loss.html">graph4nlp.loss</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../modules/evaluation.html">graph4nlp.evaluation</a></li>
</ul>
<p class="caption"><span class="caption-text">Tutorials</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../tutorial/text_classification.html">Text Classification Tutorial</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../tutorial/semantic_parsing.html">Semantic Parsing Tutorial</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../tutorial/math_word_problem.html">Math Word Problem Tutorial</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../tutorial/knowledge_graph_completion.html">Knowledge Graph Completion Tutorial</a></li>
</ul>

        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="../../index.html">Graph4NLP</a>
      </nav>

      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="Page navigation">
  <ul class="wy-breadcrumbs">
      <li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
          <li><a href="../gnn.html">Chapter 4. Graph Encoder</a> &raquo;</li>
      <li>Graph Attention Networks</li>
      <li class="wy-breadcrumbs-aside">
            <a href="../../_sources/guide/gnn/gat.rst.txt" rel="nofollow"> View page source</a>
      </li>
  </ul>
  <hr/>
</div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
             
  <div class="section" id="graph-attention-networks">
<span id="guide-gat"></span><h1>Graph Attention Networks<a class="headerlink" href="#graph-attention-networks" title="Permalink to this headline">¶</a></h1>
<p>The Graph Attention Network (<a class="reference external" href="https://arxiv.org/abs/1710.10903">GAT</a>) aims to learn edge weights for the input binary adjacency matrix by introducing
the multi-head attention mechanism to the GNN architecture when performing message passing.
We provide high level APIs to users to easily define a multi-layer GAT model. Besides, we support both
regular GAT and bidirectional versions including <a class="reference external" href="https://arxiv.org/abs/1808.07624">GAT-BiSep</a>
and <a class="reference external" href="https://arxiv.org/abs/1908.04942">GAT-BiFuse</a>. The math operation of GAT is represented as below:</p>
<blockquote>
<div><div class="math notranslate nohighlight">
\[h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{ij} W^{(l)} h_j^{(l)}\]</div>
<p>where <span class="math notranslate nohighlight">\(\alpha_{ij}\)</span> is the attention score between node <span class="math notranslate nohighlight">\(i\)</span> and
node <span class="math notranslate nohighlight">\(j\)</span>.</p>
</div></blockquote>
<p>The attention scores are computed as follows:</p>
<blockquote>
<div><div class="math notranslate nohighlight">
\[\begin{split}\alpha_{ij}^{l} = \mathrm{softmax_i} (e_{ij}^{l})\\
e_{ij}^{l} = \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right)\end{split}\]</div>
</div></blockquote>
<div class="section" id="gat-module-construction-function">
<h2>GAT Module Construction Function<a class="headerlink" href="#gat-module-construction-function" title="Permalink to this headline">¶</a></h2>
<p>The construction function performs the following steps:</p>
<ol class="arabic simple">
<li><p>Set options.</p></li>
<li><p>Register learnable parameters or submodules (<code class="docutils literal notranslate"><span class="pre">GATLayer</span></code>).</p></li>
</ol>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">GAT</span><span class="p">(</span><span class="n">GNNBase</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">output_size</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span>
        <span class="n">direction_option</span><span class="o">=</span><span class="s1">&#39;bi_sep&#39;</span><span class="p">,</span> <span class="n">feat_drop</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">attn_drop</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span>
        <span class="n">residual</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">allow_zero_in_degree</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">GAT</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">=</span> <span class="n">num_layers</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">direction_option</span> <span class="o">=</span> <span class="n">direction_option</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">gat_layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">()</span>
        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">&gt;</span> <span class="mi">0</span>
        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
            <span class="n">hidden_size</span> <span class="o">=</span> <span class="p">[</span><span class="n">hidden_size</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>

        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">heads</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
            <span class="n">heads</span> <span class="o">=</span> <span class="p">[</span><span class="n">heads</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
            <span class="c1"># input projection</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">gat_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">GATLayer</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">heads</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
                                    <span class="n">direction_option</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">direction_option</span><span class="p">,</span>
                                    <span class="n">feat_drop</span><span class="o">=</span><span class="n">feat_drop</span><span class="p">,</span> <span class="n">attn_drop</span><span class="o">=</span><span class="n">attn_drop</span><span class="p">,</span>
                                    <span class="n">negative_slope</span><span class="o">=</span><span class="n">negative_slope</span><span class="p">,</span> <span class="n">residual</span><span class="o">=</span><span class="n">residual</span><span class="p">,</span>
                                    <span class="n">activation</span><span class="o">=</span><span class="n">activation</span><span class="p">,</span> <span class="n">allow_zero_in_degree</span><span class="o">=</span><span class="n">allow_zero_in_degree</span><span class="p">))</span>

        <span class="c1"># hidden layers</span>
        <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
            <span class="c1"># due to multi-head, the input_size = hidden_size * num_heads</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">gat_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">GATLayer</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">[</span><span class="n">l</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">heads</span><span class="p">[</span><span class="n">l</span> <span class="o">-</span> <span class="mi">1</span><span class="p">],</span> <span class="n">hidden_size</span><span class="p">[</span><span class="n">l</span><span class="p">],</span>
                                        <span class="n">heads</span><span class="p">[</span><span class="n">l</span><span class="p">],</span> <span class="n">direction_option</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">direction_option</span><span class="p">,</span>
                                        <span class="n">feat_drop</span><span class="o">=</span><span class="n">feat_drop</span><span class="p">,</span> <span class="n">attn_drop</span><span class="o">=</span><span class="n">attn_drop</span><span class="p">,</span>
                                        <span class="n">negative_slope</span><span class="o">=</span><span class="n">negative_slope</span><span class="p">,</span> <span class="n">residual</span><span class="o">=</span><span class="n">residual</span><span class="p">,</span>
                                        <span class="n">activation</span><span class="o">=</span><span class="n">activation</span><span class="p">,</span> <span class="n">allow_zero_in_degree</span><span class="o">=</span><span class="n">allow_zero_in_degree</span><span class="p">))</span>
        <span class="c1"># output projection</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">gat_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">GATLayer</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">heads</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="k">else</span> <span class="n">input_size</span><span class="p">,</span>
                                    <span class="n">output_size</span><span class="p">,</span> <span class="n">heads</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">direction_option</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">direction_option</span><span class="p">,</span>
                                    <span class="n">feat_drop</span><span class="o">=</span><span class="n">feat_drop</span><span class="p">,</span> <span class="n">attn_drop</span><span class="o">=</span><span class="n">attn_drop</span><span class="p">,</span> <span class="n">negative_slope</span><span class="o">=</span><span class="n">negative_slope</span><span class="p">,</span>
                                    <span class="n">residual</span><span class="o">=</span><span class="n">residual</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">allow_zero_in_degree</span><span class="o">=</span><span class="n">allow_zero_in_degree</span><span class="p">))</span>
</pre></div>
</div>
<p>In the construction function, one first needs to set the number of GAT layers and the data dimensions (i.e., <code class="docutils literal notranslate"><span class="pre">input_size</span></code>, <code class="docutils literal notranslate"><span class="pre">hidden_size</span></code>, <code class="docutils literal notranslate"><span class="pre">output_size</span></code>).
For GAT, <code class="docutils literal notranslate"><span class="pre">heads</span></code> is also an important parameter. One should also specify <code class="docutils literal notranslate"><span class="pre">direction_option</span></code> which determines whether to use
unidirectional (i.e., undirected) or bidirectional (i.e., <cite>bi_sep</cite> and <cite>bi_fuse</cite>) version of GAT.</p>
</div>
<div class="section" id="gat-module-forward-function">
<h2>GAT Module Forward Function<a class="headerlink" href="#gat-module-forward-function" title="Permalink to this headline">¶</a></h2>
<p>In an NN module, the <code class="docutils literal notranslate"><span class="pre">forward()</span></code> function performs the actual message passing and computation. <code class="docutils literal notranslate"><span class="pre">forward()</span></code> takes a <code class="docutils literal notranslate"><span class="pre">GraphData</span></code> object as input.
The updated node embeddings are stored back into the node field <code class="docutils literal notranslate"><span class="pre">node_emb</span></code> of the GraphData, and the final output is the GraphData itself.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">):</span>
    <span class="n">feat</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">node_features</span><span class="p">[</span><span class="s1">&#39;node_feat&#39;</span><span class="p">]</span>
    <span class="n">dgl_graph</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">to_dgl</span><span class="p">()</span>

    <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">direction_option</span> <span class="o">==</span> <span class="s1">&#39;bi_sep&#39;</span><span class="p">:</span>
        <span class="n">h</span> <span class="o">=</span> <span class="p">[</span><span class="n">feat</span><span class="p">,</span> <span class="n">feat</span><span class="p">]</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">h</span> <span class="o">=</span> <span class="n">feat</span>

    <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
        <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gat_layers</span><span class="p">[</span><span class="n">l</span><span class="p">](</span><span class="n">dgl_graph</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">direction_option</span> <span class="o">==</span> <span class="s1">&#39;bi_sep&#39;</span><span class="p">:</span>
            <span class="n">h</span> <span class="o">=</span> <span class="p">[</span><span class="n">each</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">each</span> <span class="ow">in</span> <span class="n">h</span><span class="p">]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">h</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>

    <span class="c1"># output projection</span>
    <span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gat_layers</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">](</span><span class="n">dgl_graph</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>

    <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">direction_option</span> <span class="o">==</span> <span class="s1">&#39;bi_sep&#39;</span><span class="p">:</span>
        <span class="n">logits</span> <span class="o">=</span> <span class="p">[</span><span class="n">each</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">each</span> <span class="ow">in</span> <span class="n">logits</span><span class="p">]</span>
        <span class="n">logits</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>

    <span class="n">graph</span><span class="o">.</span><span class="n">node_features</span><span class="p">[</span><span class="s1">&#39;node_emb&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">logits</span>

    <span class="k">return</span> <span class="n">graph</span>
</pre></div>
</div>
</div>
<div class="section" id="gatlayer-construction-function">
<h2>GATLayer Construction Function<a class="headerlink" href="#gatlayer-construction-function" title="Permalink to this headline">¶</a></h2>
<p>To make the use of GAT more flexible, we also provide a low-level implementation of the GAT layer. Similar to the high-level API, users can specify <code class="docutils literal notranslate"><span class="pre">direction_option</span></code>, which determines whether to use
unidirectional (i.e., undirected) or bidirectional (i.e., <cite>bi_sep</cite> and <cite>bi_fuse</cite>) GAT.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">GATLayer</span><span class="p">(</span><span class="n">GNNLayerBase</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">output_size</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">direction_option</span><span class="o">=</span><span class="s1">&#39;bi_sep&#39;</span><span class="p">,</span> <span class="n">feat_drop</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span>
        <span class="n">attn_drop</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">residual</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">allow_zero_in_degree</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    <span class="nb">super</span><span class="p">(</span><span class="n">GATLayer</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    <span class="k">if</span> <span class="n">num_heads</span> <span class="o">&gt;</span>  <span class="mi">1</span> <span class="ow">and</span> <span class="n">residual</span><span class="p">:</span>
        <span class="n">residual</span> <span class="o">=</span> <span class="kc">False</span>
        <span class="kn">import</span> <span class="nn">warnings</span>
        <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;The residual option must be False when num_heads &gt; 1&quot;</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">direction_option</span> <span class="o">==</span> <span class="s1">&#39;undirected&#39;</span><span class="p">:</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">UndirectedGATLayerConv</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="n">output_size</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">feat_drop</span><span class="o">=</span><span class="n">feat_drop</span><span class="p">,</span>
                            <span class="n">attn_drop</span><span class="o">=</span><span class="n">attn_drop</span><span class="p">,</span> <span class="n">negative_slope</span><span class="o">=</span><span class="n">negative_slope</span><span class="p">,</span> <span class="n">residual</span><span class="o">=</span><span class="n">residual</span><span class="p">,</span>
                            <span class="n">activation</span><span class="o">=</span><span class="n">activation</span><span class="p">,</span> <span class="n">allow_zero_in_degree</span><span class="o">=</span><span class="n">allow_zero_in_degree</span><span class="p">)</span>
    <span class="k">elif</span> <span class="n">direction_option</span> <span class="o">==</span> <span class="s1">&#39;bi_sep&#39;</span><span class="p">:</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">BiSepGATLayerConv</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="n">output_size</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">feat_drop</span><span class="o">=</span><span class="n">feat_drop</span><span class="p">,</span>
                            <span class="n">attn_drop</span><span class="o">=</span><span class="n">attn_drop</span><span class="p">,</span> <span class="n">negative_slope</span><span class="o">=</span><span class="n">negative_slope</span><span class="p">,</span> <span class="n">residual</span><span class="o">=</span><span class="n">residual</span><span class="p">,</span>
                            <span class="n">activation</span><span class="o">=</span><span class="n">activation</span><span class="p">,</span> <span class="n">allow_zero_in_degree</span><span class="o">=</span><span class="n">allow_zero_in_degree</span><span class="p">)</span>
    <span class="k">elif</span> <span class="n">direction_option</span> <span class="o">==</span> <span class="s1">&#39;bi_fuse&#39;</span><span class="p">:</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">BiFuseGATLayerConv</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="n">output_size</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">feat_drop</span><span class="o">=</span><span class="n">feat_drop</span><span class="p">,</span>
                            <span class="n">attn_drop</span><span class="o">=</span><span class="n">attn_drop</span><span class="p">,</span> <span class="n">negative_slope</span><span class="o">=</span><span class="n">negative_slope</span><span class="p">,</span> <span class="n">residual</span><span class="o">=</span><span class="n">residual</span><span class="p">,</span>
                            <span class="n">activation</span><span class="o">=</span><span class="n">activation</span><span class="p">,</span> <span class="n">allow_zero_in_degree</span><span class="o">=</span><span class="n">allow_zero_in_degree</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;Unknown `direction_option` value: </span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">direction_option</span><span class="p">))</span>
</pre></div>
</div>
</div>
</div>


           </div>
          </div>
          <footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
        <a href="ggnn.html" class="btn btn-neutral float-left" title="Gated Graph Neural Networks" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
        <a href="graphsage.html" class="btn btn-neutral float-right" title="GraphSAGE" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
    </div>

  <hr/>

  <div role="contentinfo">
    <p>&#169; Copyright 2020, Graph4AI Group.</p>
  </div>

  Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    provided by <a href="https://readthedocs.org">Read the Docs</a>.
   

</footer>
        </div>
      </div>
    </section>
  </div>
  <script>
      jQuery(function () {
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script> 

</body>
</html>